{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "64fd87eb",
   "metadata": {},
   "source": [
    "# 蛇棋环境搭建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "90c80123",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import gym\n",
    "from gym.spaces import Discrete\n",
    "\n",
    "from contextlib import contextmanager\n",
    "import time\n",
    "\n",
    "@contextmanager\n",
    "def timer(name):\n",
    "    start = time.time()\n",
    "    yield\n",
    "    end = time.time()\n",
    "    print('{} COST:{}'.format(name, end-start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "49f59f48",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SnakeEnv(gym.Env):\n",
    "    SIZE=100\n",
    "    \n",
    "    def __init__(self, ladder_num, dices):\n",
    "        self.ladder_num = ladder_num\n",
    "        self.dices = dices\n",
    "        self.observation_space = Discrete(self.SIZE+1)\n",
    "        self.action_space = Discrete(len(dices))\n",
    "        \n",
    "        if ladder_num == 0:\n",
    "            self.ladders = {0:0}\n",
    "        else:\n",
    "            # 处理梯子值，让梯子的数值无重复地反向赋值\n",
    "            ladders = set(np.random.randint(1, self.SIZE, size=self.ladder_num*2))\n",
    "            while len(ladders) < self.ladder_num*2:\n",
    "                ladders.add(np.random.randint(1, self.SIZE))\n",
    "\n",
    "            ladders = list(ladders)\n",
    "            ladders = np.array(ladders)\n",
    "            np.random.shuffle(ladders)\n",
    "            ladders = ladders.reshape((self.ladder_num,2))\n",
    "\n",
    "            re_ladders = list()\n",
    "            for i in ladders:\n",
    "                re_ladders.append([i[1],i[0]])\n",
    "\n",
    "            re_ladders = np.array(re_ladders)\n",
    "            # dict()可以把nx2维数组转化为字典形式\n",
    "            self.ladders = dict(np.append(re_ladders, ladders, axis=0))\n",
    "        print(f'ladders info:{self.ladders} dice ranges:{self.dices}')\n",
    "        self.pos = 1\n",
    "        \n",
    "    def reset(self):\n",
    "        self.pos = 1\n",
    "        return self.pos\n",
    "    \n",
    "    def step(self, a):\n",
    "        step = np.random.randint(1, self.dices[a]+1)\n",
    "        self.pos += step\n",
    "        if self.pos == 100:\n",
    "            return 100, 100, 1, {}\n",
    "        elif self.pos > 100:\n",
    "            self.pos = 200 - self.pos\n",
    "            \n",
    "        if self.pos in self.ladders:\n",
    "            self.pos = self.ladders[self.pos]\n",
    "        return self.pos, -1, 0, {}\n",
    "    \n",
    "    def reward(self, s):\n",
    "        if s == 100:\n",
    "            return 100\n",
    "        else:\n",
    "            return -1\n",
    "    \n",
    "    # 无渲染\n",
    "    def render(self):\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b31cf9c",
   "metadata": {},
   "source": [
    "# 智能体构建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ad3343eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelFreeAgent(object):\n",
    "    def __init__(self, env):\n",
    "        self.s_len = env.observation_space.n\n",
    "        self.a_len = env.action_space.n\n",
    "        \n",
    "        self.pi = np.zeros(self.s_len, dtype=int)\n",
    "        self.value_q = np.zeros((self.s_len, self.a_len))\n",
    "        self.value_n = np.zeros((self.s_len, self.a_len))\n",
    "        self.gamma = 0.8\n",
    "    \n",
    "    \n",
    "    def play(self, state, epsilon=0.0):\n",
    "        # epsilon代表探索的概率，如果在epsilon覆盖范围内则会随机返回一个action（代表探索），否则返回目前已知\n",
    "        # 的最好策略\n",
    "        if np.random.rand() < epsilon:\n",
    "            return np.random.randint(self.a_len)\n",
    "        else:\n",
    "            return self.pi[state]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff1563de",
   "metadata": {},
   "source": [
    "# 策略评估（reward计算）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c2bb8ba1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_game(env, agent):\n",
    "    state = env.reset()\n",
    "    total_reward = 0\n",
    "    state_action = []\n",
    "    \n",
    "    while True:\n",
    "        act = agent.play(state)\n",
    "        state_action.append((state,act))\n",
    "        state, reward, done, _ = env.step(act)\n",
    "        total_reward += reward\n",
    "        if done:\n",
    "            break\n",
    "    \n",
    "    return total_reward, state_action"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "660319c2",
   "metadata": {},
   "source": [
    "# 算法"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "707725c3",
   "metadata": {},
   "source": [
    "## 1. Monte Carlo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "766657f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MonteCarlo(object):\n",
    "    def __init__(self, epsilon=0.0):\n",
    "        self.epsilon = epsilon\n",
    "            \n",
    "    \n",
    "    def monte_carlo_eval(self, agent, env):\n",
    "        state = env.reset()\n",
    "        episode = []\n",
    "        while True:\n",
    "            ac = agent.play(state, self.epsilon)\n",
    "            next_state, reward, done, _ = env.step(ac)\n",
    "            episode.append((state, ac, reward))\n",
    "            state = next_state\n",
    "            if done:\n",
    "                break\n",
    "        \n",
    "        value = []\n",
    "        return_val = 0\n",
    "        for item in reversed(episode):\n",
    "            # return_val 当前状态之后的所有回报乘以对应的打折率，求和\n",
    "            return_val = return_val*agent.gamma + item[2]\n",
    "            value.append((item[0], item[1], return_val))\n",
    "        \n",
    "        # 求迭代value_n次后的长期回报均值\n",
    "        for item in reversed(value):\n",
    "            agent.value_n[item[0]][item[1]] += 1\n",
    "            agent.value_q[item[0]][item[1]] += (item[2]-agent.value_q[item[0]][item[1]])/agent.value_n[item[0]][item[1]]\n",
    "    \n",
    "    \n",
    "    def policy_improve(self, agent):\n",
    "        # 如果用np.zeros(agent.pi)会报错\"ValueError: maximum supported dimension for an ndarray is 32, found 101\"\n",
    "        new_policy = np.zeros_like(agent.pi)\n",
    "        for i in range(1, agent.s_len):\n",
    "            new_policy[i] = np.argmax(agent.value_q[i,:])\n",
    "        # 之前if缩进在后面，会导致pi数组基本全为0。因为在policy未完全更新前进行判断，导致提前退出函数\n",
    "        if np.all(np.equal(new_policy, agent.pi)):\n",
    "            return False\n",
    "        else:\n",
    "            agent.pi = new_policy\n",
    "            return True\n",
    "    \n",
    "    \n",
    "    def monte_carlo_opt(self, agent, env):\n",
    "        iteration = 0\n",
    "        while True:\n",
    "            iteration += 1\n",
    "            for i in range(100):\n",
    "                self.monte_carlo_eval(agent, env)\n",
    "            ret = self.policy_improve(agent)\n",
    "            if not ret:\n",
    "                break\n",
    "        print('Monte Carlo: {} rounds'.format(iteration))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a170fa3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def monte_carlo_demo():\n",
    "    env = SnakeEnv(10, [3,6])\n",
    "    agent = ModelFreeAgent(env)\n",
    "    mc = MonteCarlo(0.05)\n",
    "    with timer('Timer Monte Carlo Iter'):\n",
    "        mc.monte_carlo_opt(agent, env)\n",
    "    print('return_pi={}'.format(eval_game(env, agent)))\n",
    "    print('agent.pi={}'.format(agent.pi))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c15a47d7",
   "metadata": {},
   "source": [
    "## 2. TD (Temporal Difference)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "13a458e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 改一处公式就是Q-Learning\n",
    "class SARSA(object):\n",
    "    def __init__(self, epsilon=0.0):\n",
    "        self.epsilon = epsilon\n",
    "    \n",
    "    \n",
    "    def sarsa_eval(self, agent, env):\n",
    "        state = env.reset()\n",
    "        prev_act = -1\n",
    "        prev_state = -1\n",
    "\n",
    "        while True:\n",
    "            act = agent.play(state, self.epsilon)\n",
    "            next_state, reward, done, _ = env.step(act)\n",
    "            if prev_act != -1:\n",
    "                # Q Learning与SARSA的区别就是np.max这里\n",
    "                # return_val = reward + agent.gamma * (0 if done else np.max(agent.value_q[state,:]))\n",
    "                return_val = reward + agent.gamma * (0 if done else agent.value_q[state][act])\n",
    "                agent.value_n[prev_state][prev_act] += 1\n",
    "                agent.value_q[prev_state][prev_act] += (return_val - agent.value_q[prev_state][prev_act]) / \\\n",
    "                                                        agent.value_n[prev_state][prev_act]\n",
    "            prev_act = act\n",
    "            prev_state = state\n",
    "            state = next_state\n",
    "            if done:\n",
    "                break\n",
    "    \n",
    "    \n",
    "    def policy_improve(self, agent):\n",
    "        new_policy = np.zeros_like(agent.pi)\n",
    "        for i in range(1, agent.s_len):\n",
    "            new_policy[i] = np.argmax(agent.value_q[i,:])\n",
    "        if np.all(np.equal(new_policy, agent.pi)):\n",
    "            return False\n",
    "        else:\n",
    "            agent.pi = new_policy\n",
    "            return True\n",
    "        \n",
    "        \n",
    "    def sarsa_opt(self, agent, env):\n",
    "        iteration = 0\n",
    "        while True:\n",
    "            iteration += 1\n",
    "            for i in range(100):\n",
    "                self.sarsa_eval(agent, env)\n",
    "            ret = self.policy_improve(agent)\n",
    "            if not ret:\n",
    "                break\n",
    "        print('SARSA: {} rounds'.format(iteration))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b5e172f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sarsa_demo():\n",
    "    env = SnakeEnv(10, [3,6])\n",
    "    agent = ModelFreeAgent(env)\n",
    "    sarsa_algo = SARSA(0.05)\n",
    "    with timer('Sarsa Iter'):\n",
    "        sarsa_algo.sarsa_opt(agent, env)\n",
    "    print('return_pi={}'.format(eval_game(env, agent)))\n",
    "    print('agent.pi={}'.format(agent.pi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d889d042",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ladders info:{88: 94, 13: 41, 12: 82, 6: 84, 5: 78, 71: 50, 65: 93, 64: 10, 76: 61, 96: 9, 94: 88, 41: 13, 82: 12, 84: 6, 78: 5, 50: 71, 93: 65, 10: 64, 61: 76, 9: 96} dice ranges:[3, 6]\n",
      "Monte Carlo: 14 rounds\n",
      "Timer Monte Carlo Iter COST:0.28543591499328613\n",
      "return_pi=(92, [(1, 1), (7, 0), (96, 1), (98, 0), (99, 0), (99, 0), (99, 0), (99, 0), (98, 0)])\n",
      "agent.pi=[0 1 1 1 1 1 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 1 0 1 1 0 0 1 1 1 0 0 1 0 0 1 0\n",
      " 1 1 0 1 1 0 0 0 1 1 0 1 1 1 0 0 0 0 0 1 0 1 0 0 1 0 0 0 0 0 1 1 1 1 1 1 1\n",
      " 0 1 1 1 1 1 1 0 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0]\n",
      "ladders info:{36: 51, 31: 34, 69: 83, 44: 91, 9: 53, 55: 19, 49: 47, 98: 4, 54: 67, 70: 22, 51: 36, 34: 31, 83: 69, 91: 44, 53: 9, 19: 55, 47: 49, 4: 98, 67: 54, 22: 70} dice ranges:[3, 6]\n",
      "SARSA: 45 rounds\n",
      "Sarsa Iter COST:1.673116683959961\n",
      "return_pi=(99, [(1, 0), (98, 0)])\n",
      "agent.pi=[0 0 0 0 1 1 0 0 0 0 1 1 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1\n",
      " 1 1 1 1 0 0 0 1 1 1 1 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 1\n",
      " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 1 0 1 0 0 0]\n"
     ]
    }
   ],
   "source": [
    "if __name__ == '__main__':\n",
    "    monte_carlo_demo()\n",
    "    sarsa_demo()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
